{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Calling Model with Executor\n",
    "This tutorial shows how to train and evaluate a model with `Executor` and `Loader`.\n",
    "- Training `model`\n",
    "- Feed `Loader` into `model`\n",
    "\n",
    "In tutorial 1 we showed how to run inference on images with a specific model.\n",
    "But that approach seems a little complicated:\n",
    "- What's the argument of the `eval` function?\n",
    "- What's the data type it returns?\n",
    "\n",
    "Therefore we provide an easier way to call models, which we call `Executor`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "name": "stderr",
     "text": [
      "2020-02-11 22:45:13,229 INFO: Total params: 57281\n",
      "2020-02-11 22:45:13,231 INFO: Fitting: [SRCNN]\n",
      "100%|##########| 10/10 [00:01<00:00,  6.34batch/s, loss=00.00814, image=00.00814]\n",
      "2020-02-11 22:45:14,837 INFO: Training [SRCNN] finished.\n"
     ],
     "output_type": "stream"
    },
    {
     "name": "stdout",
     "text": [
      "['epochs', 'batch_shape', 'steps', 'val_steps', 'lr', 'lr_schedule', 'memory_limit', 'inference_results_hooks', 'validate_every_n_epoch', 'traced_val', 'ensemble', 'cuda', 'map_location', 'train_loader', 'val_loader', 'epoch', 'avg_meas']\n",
      "| 2020-02-11 22:45:13 | Epoch: 1/1 | LR: 0.001 |\n",
      "| Epoch average loss = 0.053390 |\n",
      "| Epoch average image = 0.053390 |\n"
     ],
     "output_type": "stream"
    }
   ],
   "source": [
    "from VSR.DataLoader import Dataset, Loader\n",
    "from VSR.Model import get_model\n",
    "\n",
    "# Build an SRCNN model with scale=3 and a single image channel.\n",
    "srcnn_builder = get_model('srcnn')\n",
    "srcnn = srcnn_builder(scale=3, channel=1)\n",
    "# Try to load `srcnn.pth`; a missing file is ignored and the model is used as built.\n",
    "try:\n",
    "    srcnn.load(\"srcnn.pth\")\n",
    "except FileNotFoundError:\n",
    "    pass\n",
    "\n",
    "# Collect every PNG under the test-data folder and wrap it in a loader.\n",
    "dataset = Dataset('../../Tests/data').include('*.png')\n",
    "loader = Loader(dataset, scale=3, threads=8)\n",
    "# convert both the 'hr' and 'lr' images to grayscale ('L')\n",
    "for image_key in ('hr', 'lr'):\n",
    "    loader.set_color_space(image_key, 'L')\n",
    "\n",
    "# root=None: no working directory is given to the executor.\n",
    "with srcnn.get_executor(root=None) as executor:\n",
    "    # Start from the default configuration and show which options exist.\n",
    "    config = executor.query_config({})\n",
    "    available_options = list(config.keys())\n",
    "    print(available_options)\n",
    "    # Override only what this short demo run needs.\n",
    "    config.epochs = 1\n",
    "    config.steps = 10\n",
    "    config.batch_shape = [4, 1, 16, 16]\n",
    "    config.lr = 1e-3\n",
    "    # [training loader, validation loader]; no validation loader is supplied here.\n",
    "    loaders = [loader, None]\n",
    "    executor.fit(loaders, config)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": false
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "The method `get_executor(root)` builds a model executor with a working directory.\n",
    "If `root` is not set, the executor runs without saving models.\n",
    "\n",
    "Use a `with` statement to construct a running environment for it.\n",
    "\n",
    "To configure the executor, call `query_config({})` to get the default configuration.\n",
    "\n",
    "To train the model, call the `fit` method. The first argument is a list of two loaders:\n",
    "the first one is the loader for training and the second one is the loader for validation (it can be `None`).\n",
    "\n",
    "To evaluate the model or run inference with it, call the `benchmark` or `infer` method."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [
    {
     "name": "stderr",
     "text": [
      "2020-02-11 22:46:43,370 INFO: Total params: 57281\n",
      "2020-02-11 22:46:43,381 INFO: Inferring [srcnn] at epoch 0\n",
      "Infer: 100%|##########| 16/16 [00:19<00:00,  1.24s/it]\n"
     ],
     "output_type": "stream"
    },
    {
     "name": "stdout",
     "text": [
      "1 (1, 1, 384, 510)\n",
      "['img0']\n",
      "1 (1, 1, 384, 510)\n",
      "['img1']\n",
      "1 (1, 1, 375, 1242)\n",
      "['c_10']\n",
      "1 (1, 1, 375, 1242)\n",
      "['c_11']\n",
      "1 (1, 1, 375, 1242)\n",
      "['f_01']\n",
      "1 (1, 1, 126, 126)\n",
      "['img_003_SRF_2_LR']\n",
      "1 (1, 1, 255, 255)\n",
      "['img_001_SRF_2_LR']\n",
      "1 (1, 1, 138, 138)\n",
      "['img_004_SRF_2_LR']\n",
      "1 (1, 1, 144, 144)\n",
      "['img_002_SRF_2_LR']\n",
      "1 (1, 1, 270, 480)\n",
      "['xiuxian001']\n",
      "1 (1, 1, 171, 114)\n",
      "['img_005_SRF_2_LR']\n",
      "1 (1, 1, 135, 240)\n",
      "['xiuxian002']\n",
      "1 (1, 1, 135, 240)\n",
      "['xiuxian001']\n",
      "1 (1, 1, 135, 240)\n",
      "['xiuxian003']\n",
      "1 (1, 1, 270, 480)\n",
      "['xiuxian002']\n",
      "1 (1, 1, 270, 480)\n",
      "['xiuxian003']\n"
     ],
     "output_type": "stream"
    }
   ],
   "source": [
    "from VSR.DataLoader import Dataset, Loader\n",
    "from VSR.Model import get_model\n",
    "\n",
    "def my_hook(outputs: list, names: list):\n",
    "    \"\"\"Print a short summary of one batch of inference results.\n",
    "\n",
    "    Prints the number of entries in `outputs` together with the `.shape`\n",
    "    of the first entry, then prints `names` as-is.\n",
    "\n",
    "    Args:\n",
    "        outputs: list of results; the first entry must expose `.shape`\n",
    "            (presumably an image array -- confirm against the executor).\n",
    "        names: list printed alongside the outputs (presumably the names\n",
    "            of the processed images -- confirm against the executor).\n",
    "    \"\"\"\n",
    "    n_outputs = len(outputs)\n",
    "    first_shape = outputs[0].shape\n",
    "    print(n_outputs, first_shape)\n",
    "    print(names)\n",
    "\n",
    "# Build an SRCNN model with scale=3 and a single image channel.\n",
    "srcnn_builder = get_model('srcnn')\n",
    "srcnn = srcnn_builder(scale=3, channel=1)\n",
    "# Try to load `srcnn.pth`; a missing file is ignored and the model is used as built.\n",
    "try:\n",
    "    srcnn.load(\"srcnn.pth\")\n",
    "except FileNotFoundError:\n",
    "    pass\n",
    "\n",
    "# Collect every PNG under the test-data folder and wrap it in a loader.\n",
    "dataset = Dataset('../../Tests/data').include('*.png')\n",
    "loader = Loader(dataset, scale=3, threads=8)\n",
    "# convert both the 'hr' and 'lr' images to grayscale ('L')\n",
    "for image_key in ('hr', 'lr'):\n",
    "    loader.set_color_space(image_key, 'L')\n",
    "\n",
    "# root=None: no working directory is given to the executor.\n",
    "with srcnn.get_executor(root=None) as executor:\n",
    "    config = executor.query_config({})\n",
    "    # Register the hook through the `inference_results_hooks` option.\n",
    "    hooks = [my_hook]\n",
    "    config.inference_results_hooks = hooks\n",
    "    # executor.benchmark(loader, config)\n",
    "    executor.infer(loader, config)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": false
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}